import os
import re
import json
import torch
import textwrap
from rich import print
from openai import OpenAI
from policies.utils.transition import Transition
from policies import prompts

# API key for the OpenAI backend, read once at import time (None when unset).
OPENAI_API_KEY = os.environ.get("OPENAI_API_KEY")
# Separator inserted between sections of the per-transition reflection prompts.
delimiter = "\n--------\n"

class Reflection:
    """LLM-driven reflection, self-correction and debriefing for a driving agent.

    The backend is either an OpenAI chat model (``model`` is a string that
    contains "gpt") or a local model/tokenizer pair queried through
    ``generate``. Knowledge distilled from past transitions is stored on the
    instance: ``learned_knowledge`` (this agent's own notes) and
    ``cooperative_knowledge`` (the strategy produced in the last debrief round).
    """

    def __init__(self, model=None,
                 tokenizer=None,
                 device="cuda",
                 temperature=0.2,
                 comm_only=False,
                 control_only=False,
                 is_focal=False,
                 ):
        """
        Args:
            model: OpenAI model name (any string containing "gpt") or a local
                model object exposing ``generate``.
            tokenizer: Tokenizer for the local model; unused for GPT models.
            device: Preferred torch device; falls back to CPU without CUDA.
            temperature: Temperature passed to the chat backend.
            comm_only: Agent only sends messages (command tips/corrections off).
            control_only: Agent only issues commands (message corrections off).
            is_focal: Whether this agent owns the task; changes how success and
                failure are phrased in prompts.
        """
        self.model = model
        if isinstance(model, str) and "gpt" in model:
            self.model_type = "gpt"
            self.openai_api_key = OPENAI_API_KEY
        else:
            self.model_type = "llama"
            self.openai_api_key = None
        self.tokenizer = tokenizer
        # Stop generation on EOS or on "<|eot_id|>" (presumably a Llama-3 style
        # chat template — confirm for other local models).
        self.terminators = None
        if tokenizer:
            self.terminators = [
                tokenizer.eos_token_id,
                tokenizer.convert_tokens_to_ids("<|eot_id|>"),
            ]
        self.device = torch.device(device if torch.cuda.is_available() else "cpu")
        self.temperature = temperature
        self.learned_knowledge = None
        self.cooperative_knowledge = None
        self.comm_only = comm_only
        self.control_only = control_only
        self.is_focal = is_focal
        # OpenAI client, created lazily by chat_gpt and reused afterwards.
        self._openai_client = None

    def reflect(self, batch):
        """Distil knowledge from ``batch`` (a list of Transition objects) and return it."""
        self.learned_knowledge = self.prompting_tip(batch)
        return self.learned_knowledge

    def correct(self, transition):
        """Ask the model to correct ``transition``'s action and return the updated transition."""
        correction = self.prompting_correction(transition)
        return self.parse_correction(correction, transition)

    def chat(self, prompt, json_format=False):
        """Send ``prompt`` (a list of role/content dicts) to the backend and return the reply text."""
        if self.model_type == "gpt":
            return self.chat_gpt(prompt, json_format)
        return self.chat_llama(prompt, json_format)

    def chat_llama(self, prompt, json_format=False):
        """Generate a reply with the local model.

        ``json_format`` is accepted for parity with ``chat_gpt`` but is not
        enforced by this backend.
        """
        input_ids = self.tokenizer.apply_chat_template(prompt,
                                                       add_generation_prompt=True,
                                                       return_tensors='pt').to(self.device)
        with torch.no_grad():
            # NOTE(review): in Hugging Face ``generate``, ``temperature`` only
            # matters with ``do_sample=True``, and ``max_length`` includes the
            # prompt tokens — confirm both are intended.
            outputs = self.model.generate(input_ids,
                                          max_length=4096,
                                          temperature=self.temperature,
                                          eos_token_id=self.terminators
                                          )
        # Decode only the newly generated tokens, not the echoed prompt.
        return self.tokenizer.decode(outputs[0][input_ids.shape[-1]:], skip_special_tokens=True)

    def chat_gpt(self, prompt, json_format=False):
        """Query the OpenAI chat completions API and return the stripped reply.

        When ``json_format`` is true the request asks for a JSON-object response;
        otherwise ``response_format`` is omitted entirely.
        """
        client = getattr(self, "_openai_client", None)
        if client is None:
            client = OpenAI(api_key=self.openai_api_key)
            self._openai_client = client
        request = {
            "model": self.model,
            "messages": prompt,
            "temperature": self.temperature,
            "stream": False,
        }
        if json_format:
            request["response_format"] = {"type": "json_object"}
        response = client.chat.completions.create(**request)
        # message.content is Optional in the OpenAI SDK; normalise to a string.
        content = response.choices[0].message.content
        return (content or "").strip()

    # ------------------------------------------------------------------ #
    # Prompt-building helpers shared by prompting_tip / prompting_correction / debrief
    # ------------------------------------------------------------------ #

    def _append_control_tips(self, system_msg, first_index):
        """Append the route/command tips (numbered from ``first_index``) unless comm_only."""
        if self.comm_only:
            return system_msg
        return (
            system_msg
            + f"{first_index}. You have a predefined route to follow, so you just need to adjust your speed and follow the route.\n"
            + f"{first_index + 1}. Only the following commands are available for agents now: 1. stop (slow down the vehicle) 2. go (keep speed and follow the planned route) 3. speed up (accelerate the vehicle)\n"
        )

    @staticmethod
    def _describe_action(action, header):
        """Render the reasoning/message/command entries present in ``action`` after ``header``."""
        text = header
        for key, label in (("reasoning", "Your reasoning was"),
                           ("message", "Your generated message was"),
                           ("command", "Your driving command was")):
            if key in action:
                text += f"{label}: {action[key]}\n"
        return text

    @staticmethod
    def _describe_reactions(other_reactions, agent_vehicle_mapping):
        """Render other vehicles' reactive reasoning/command/message, keyed by agent id."""
        text = "Other Vehicles' Reaction to your action: \n"
        for agent_id, reaction in other_reactions.items():
            other_vehicle_id = agent_vehicle_mapping[agent_id]
            for key in ("reasoning", "command", "message"):
                if reaction.get(key, None) is not None:
                    text += f"The other Vehicle {other_vehicle_id}'s reactive {key} was: {reaction[key]} \n"
        return text

    def _describe_consequence(self, transition, vehicle_id, use_stagnation_description=False):
        """Describe the outcome recorded in ``transition.feedback``.

        Failures are attributed to a collision first, then to stagnation. When
        neither flag is set, a plain failure sentence is returned. The previous
        inline version left ``reason`` unbound in that case.
        ``use_stagnation_description`` lets a focal agent use the feedback's
        own "stagnation_description" text when it is present.
        """
        feedback = transition.feedback
        if feedback["success"]:
            if self.is_focal:
                return "You successfully completed your task."
            return "You successfully helped others to complete their task."

        collision_info = feedback["collision_info"]
        if collision_info["collision_occurred"]:
            reason = f"{collision_info['collision_description']} in {transition.time_to_collision} seconds. "
            if self.is_focal:
                reason += f"Note that you are Vehicle {vehicle_id}. "
            else:
                reason += "You failed at helping them avoid the collision. "
        elif feedback["stagnation_info"]["stagnation_occurred"]:
            default_reason = "You stagnated for too long and ran out of time to complete your task."
            if not self.is_focal:
                reason = "Someone stagnated for too long and ran out of time to complete their task."
            elif use_stagnation_description:
                reason = feedback["stagnation_info"].get("stagnation_description", default_reason)
            else:
                reason = default_reason
        else:
            return "You failed to complete the task."
        return f"You failed to complete the task, because {reason}"

    @staticmethod
    def _extract_json_object(text):
        """Return the first dict decodable from a ``{...}`` span of ``text``, else ``{}``.

        The widest span (first "{" to last "}") is tried first, so JSON that
        contains "*" (e.g. markdown bold) still parses. After that, the
        narrower brace spans matched by the legacy regex are tried.
        """
        candidates = []
        start, end = text.find("{"), text.rfind("}")
        if start != -1 and end > start:
            candidates.append(text[start:end + 1])
        candidates.extend(re.findall(r"\{[^*]*\}", text))
        for candidate in candidates:
            try:
                parsed = json.loads(candidate)
            except ValueError:  # json.JSONDecodeError subclasses ValueError
                continue
            if isinstance(parsed, dict):
                return parsed
        return {}

    # ------------------------------------------------------------------ #
    # Reflection / correction / debrief
    # ------------------------------------------------------------------ #

    def prompting_tip(self, batch):
        """Reflect on ``batch`` (a list of Transition objects) and distil knowledge.

        The model is asked two things in turn. First, it analyses the
        observations, actions and consequences. Then it writes (or revises)
        ``self.learned_knowledge``. Transitions with ``obs is None`` are
        skipped. If the model answers with fewer than 10 words, the previous
        knowledge (possibly None) is kept. The result is stored on the
        instance and returned.
        """
        vehicle_id = batch[0].vehicle_id
        task = batch[0].obs['task']

        SYSTEM_MSG = """\
        You are reflecting your past experience to learn from it. Here are some tips for the reflection session:
        1. Sign of the Lane ID indicate the direction of the lane.
        2. You and other vehicles may not be able to see all the vehicles in the scene because of possible occlusions.
        3. Your received messages are from other vehicles.
        4. Keep in mind what your task is and complete it as fast as possible.
        5. Vehicle ID(usually a number) are not allowed to be mentioned in the strategy and knowledge, but you can describe them without Vehicle ID.
        """
        SYSTEM_MSG = self._append_control_tips(textwrap.dedent(SYSTEM_MSG), first_index=6)
        prompt = [{"role": "system", "content": SYSTEM_MSG}]

        PROMPT1 = f"You are Vehicle {vehicle_id}. Your task was {task} Here are some observation and the action you took along with the consequences:"
        prompt.append({"role": "user", "content": PROMPT1})
        for transition in batch:
            if transition.obs is None:
                continue
            episode = (get_obs_message(transition.obs)
                       + self._describe_action(transition.action, "Action: ")
                       + delimiter
                       + self._describe_consequence(transition, vehicle_id))
            prompt.append({"role": "user", "content": episode})
        analysis_question = """\
              Analyze the consequences of your action given observations.
              Since your sensors are limited, you may not be able to see all the vehicles in the scene because of possible occlusions, if any, find those occluded vehicles through discussion.
              If there is any vehicle causes potential threat or situation that you think is accident-prone or inefficient (causes stagnation), please provide a reason why you think so.
              """
        prompt.append({"role": "user", "content": textwrap.dedent(analysis_question)})

        # Reasoning step.
        analysis = self.chat(prompt)
        prompt.append({"role": "assistant", "content": analysis})

        # Knowledge for this agent: revise the previous knowledge if there is any.
        if self.learned_knowledge:
            PROMPT5 = f"Here was the knowledge about completing your task from your past experience: {self.learned_knowledge} \n"
            PROMPT5 += "Revise the knowledge for yourself to keep in mind for future driving, based on your analysis and proposed strategy.\n"
        else:
            PROMPT5 = "Based on your analysis, summarize a knowledge for yourself to keep in mind for future driving or completing your tasks.\n"
        PROMPT5 += "Note that you are not allowed to mention Vehicle ID in the knowledge because they can change."
        prompt.append({"role": "user", "content": PROMPT5})
        knowledge = self.chat(prompt)
        if len(knowledge.split()) < 10:
            print("The knowledge you proposed is too short. Please provide a more detailed knowledge.")
            knowledge = self.learned_knowledge
        self.learned_knowledge = knowledge
        return knowledge

    def prompting_correction(self, transition):
        """Ask the model to correct the action taken in ``transition``.

        Returns the raw model reply. The reply should contain a JSON object with
        ``corrected_reasoning`` plus ``corrected_message`` and/or
        ``corrected_command``, depending on ``comm_only``/``control_only``.
        Use ``parse_correction`` to apply it.
        """
        vehicle_id = transition.vehicle_id
        task = transition.obs['task']

        SYSTEM_MSG = """\
        You are reflecting about your past experience and trying to correct your reasoning, command, and message. Here are some tips for the reflection:
        1. Sign of the Lane ID indicate the direction of the lane.
        2. You and other vehicles may not be able to see all the vehicles in the scene because of possible occlusions.
        3. Your received messages are from other vehicles.
        4. Keep in mind what your task is and complete it as fast as possible.
        """
        SYSTEM_MSG = self._append_control_tips(textwrap.dedent(SYSTEM_MSG), first_index=5)
        prompt = [{"role": "system", "content": SYSTEM_MSG}]

        PROMPT1 = f"You are Vehicle {vehicle_id}. Your task was {task} Here is some observations and the actions you took, along with the consequences:"
        # get_obs_message already starts with "Observation: ".
        obs_message = get_obs_message(transition.obs) if transition.obs is not None else ""
        PROMPT2 = (obs_message
                   + self._describe_action(transition.action, "Action: ")
                   + self._describe_consequence(transition, vehicle_id))
        analysis_question = """\
              Analyze the consequences of your action given observations.
              Since your sensors are limited, you may not be able to see all the vehicles in the scene because of possible occlusions, if any, find those occluded vehicles through discussion.
              If there is any vehicle causes potential threat or situation that you think is accident-prone or inefficient (causes stagnation), please provide a reason why you think so.
              """
        prompt.append({"role": "user",
                       "content": PROMPT1 + "\n" + PROMPT2 + "\n" + textwrap.dedent(analysis_question)})

        # Reasoning step.
        analysis = self.chat(prompt)
        prompt.append({"role": "assistant", "content": analysis})

        # Correction request; the fields asked for depend on the agent's modality.
        if self.comm_only:
            reflection_question = "Given the observation, reasoning, actions you took, and consequences, revise your reasoning and message given the observation."
            response_schema = """\
                corrected_reasoning: <reasoning>,
                corrected_message: <message>,
                where you replace <reasoning> with your corrected reasoning,
                replace <message> with your corrected message, 
                Do nothing else but return the corrected_reasoning and corrected_message."""
        elif self.control_only:
            reflection_question = "Given the observation, reasoning, actions you took, and consequences, revise your reasoning and command given the observation."
            response_schema = """\
                corrected_reasoning: <reasoning>,
                corrected_command: <command>,
                where you replace <reasoning> with your corrected reasoning,
                replace <command> with your corrected command, 
                Do nothing else but return the corrected_reasoning and corrected_command."""
        else:
            reflection_question = "Given the observation, reasoning, actions you took, and consequences, revise your reasoning, command, and message given the observation."
            response_schema = """\
                corrected_reasoning: <reasoning>,
                corrected_command: <command>,
                corrected_message: <message>,
                where you replace <reasoning> with your corrected reasoning,
                replace <command> with your corrected command, 
                replace <message> with your corrected message.
                Do nothing else but return the corrected_reasoning, corrected_command and corrected_message."""
        response_schema = textwrap.dedent(response_schema)
        limitation = f"You MUST respond with a JSON object with the following structure:\n{response_schema}\n"
        prompt.append({"role": "user", "content": reflection_question + "\n" + limitation})
        return self.chat(prompt)

    def parse_correction(self, response, transition):
        """Apply corrections from the model ``response`` to ``transition.action`` in place.

        ``corrected_reasoning`` always replaces the reasoning.
        ``corrected_message`` is ignored for control-only agents, and
        ``corrected_command`` is ignored for comm-only agents. If no JSON
        object can be decoded from the reply, the transition is left
        unchanged. Returns ``transition``.
        """
        corrections = self._extract_json_object(response or "")
        if "corrected_reasoning" in corrections:
            transition.action["reasoning"] = corrections["corrected_reasoning"]
        if "corrected_message" in corrections and not self.control_only:
            transition.action["message"] = corrections["corrected_message"]
        if "corrected_command" in corrections and not self.comm_only:
            transition.action["command"] = corrections["corrected_command"]
        return transition

    def debrief(self, batch, collective_knowledges, ego_knowledges, agent_in_debrief, last_round=False):
        """Take one turn in a multi-vehicle debriefing discussion.

        Args:
            batch: list of Transition objects for this vehicle. ``batch[0]``
                supplies the vehicle id, task, agent id and the
                ``agent_vehicle_mapping`` (agent id -> vehicle id).
            collective_knowledges: discussion so far, as a list of
                ``{agent_id: text}`` dicts.
            ego_knowledges: accepted for interface compatibility; unused here.
            agent_in_debrief: agent ids taking part in the debrief.
            last_round: when true, the produced strategy and knowledge are
                stored on the instance.

        Returns:
            ``(strategy, knowledge)``. An answer shorter than 10 words is
            replaced by the previously stored value, which may be None.
        """
        vehicle_id = batch[0].vehicle_id
        task = batch[0].obs['task']
        agent_vehicle_mapping = batch[0].feedback['agent_vehicle_mapping']
        own_agent_id = batch[0].agent_id
        # The other participants, named by vehicle id (this agent excluded).
        vehicle_in_debrief = ', '.join(
            f'Vehicle {agent_vehicle_mapping[agent_id]}'
            for agent_id in agent_in_debrief
            if agent_id != own_agent_id
        )

        SYSTEM_MSG = """\
        You are now in a debriefing session with other vehicles. Here are some tips for the debriefing session:
        1. Sign of the Lane ID indicate the direction of the lane.
        2. You and other vehicles may not be able to see all the vehicles in the scene because of possible occlusions.
        3. Your received messages are from other vehicles.
        4. Keep in mind what your task is and complete it as fast as possible.
        5. Vehicle ID(usually a number) are not allowed to be mentioned in the final strategy and knowledge, but you can describe them without Vehicle ID.
        """
        SYSTEM_MSG = self._append_control_tips(textwrap.dedent(SYSTEM_MSG), first_index=6)
        prompt = [{"role": "system", "content": SYSTEM_MSG}]

        PROMPT1 = f"You are Vehicle {vehicle_id}. Your task was {task} Here are some observation and the action you took, and other vehicles' response, along with the consequences:"
        prompt.append({"role": "user", "content": PROMPT1})
        for transition in batch:
            episode = get_obs_message(transition.obs)
            episode += self._describe_action(transition.action, "Your action: ")
            if transition.other_reactions:
                episode += self._describe_reactions(transition.other_reactions, agent_vehicle_mapping)
                episode += delimiter
            episode += self._describe_consequence(transition, vehicle_id, use_stagnation_description=True)
            prompt.append({"role": "user", "content": episode})
        analysis_question = """\
              Analyze the consequences of your action and other vehicles' reactions.
              Since your sensors are limited, you may not be able to see all the vehicles in the scene because of possible occlusions, if any, find those occluded vehicles through discussion.
              If there is any vehicle causes potential threat or situation that you think is accident-prone or inefficient (causes stagnation), please provide a reason why you think so.
              If your task conflicts with other vehicles' tasks, discuss the conflict and propose a solution.
              If you successfully completed the task, discuss what makes the task successful.
              """
        prompt.append({"role": "user", "content": textwrap.dedent(analysis_question)})
        analysis = self.chat(prompt)
        prompt.append({"role": "assistant", "content": analysis})

        # Join the turn-based debrief discussion.
        if self.cooperative_knowledge:
            prompt.append({"role": "user", "content": f"In this episode, you were following the cooperative strategy/plan: {self.cooperative_knowledge}. \n"})
        if collective_knowledges:
            dialogue = f"Here is your discussion dialogue with other cooperative vehicles ({vehicle_in_debrief}) so far: \n"
            for turn in collective_knowledges:
                for agent_id, knowledge in turn.items():
                    speaker = agent_vehicle_mapping[agent_id]
                    marker = "(yourself)" if speaker == vehicle_id else ""
                    dialogue += f"Vehicle {speaker}{marker}: {knowledge} \n"
            prompt.append({"role": "user", "content": dialogue})
            if last_round:
                cooperation_question = f"Summarize and revise the cooperative strategy/plan for all vehicles (including yourself and {vehicle_in_debrief}) in one paragraph based on your analysis. You should try to collaborate with others if conflict arises.\n"
            else:
                cooperation_question = f"Revise or add to the cooperative strategy/plan for all vehicles (including yourself and {vehicle_in_debrief}) in one paragraph based on your analysis. You should try to collaborate with others if conflict arises.\n"
        else:
            cooperation_question = f"Propose a coordination strategy/plan for your next interactions with all cooperative vehicle (including yourself and {vehicle_in_debrief}) in one paragraph. You should try to collaborate with others if conflict arises.\n"
        cooperation_question += f"Note that you are not allowed to mention Vehicle ID and their model in the strategy because they can change, but you can describe them. And your task is {task}"
        prompt.append({"role": "user", "content": cooperation_question})
        strategy = self.chat(prompt)
        if len(strategy.split()) < 10:
            print("The strategy you proposed is too short. Please provide a more detailed strategy.")
            strategy = self.cooperative_knowledge  # fall back to the previous strategy
        if strategy is not None:
            # Never send a None-content assistant turn to the backend.
            prompt.append({"role": "assistant", "content": strategy})
        if last_round:
            self.cooperative_knowledge = strategy

        # Knowledge for this agent.
        if self.learned_knowledge:
            PROMPT5 = f"Here was the knowledge about completing your task from your past experience: {self.learned_knowledge} \n"
            prompt.append({"role": "user", "content": PROMPT5})
            knowledge_question = "Revise the knowledge for yourself to keep in mind for future driving, based on your analysis and proposed strategy. \n"
        else:
            knowledge_question = f"Based on your analysis and the cooperation strategy, summarize a knowledge for yourself to keep in mind for executing the task: {task}\n"
        knowledge_question += "Note that you are not allowed to mention Vehicle ID and their model in the knowledge because they can change, but you can describe them. "
        prompt.append({"role": "user", "content": knowledge_question})
        knowledge = self.chat(prompt)
        if len(knowledge.split()) < 10:
            print("The knowledge you proposed is too short. Please provide a more detailed knowledge.")
            knowledge = self.learned_knowledge  # fall back to the previous knowledge
        if knowledge is not None:
            prompt.append({"role": "assistant", "content": knowledge})
        if last_round:
            self.learned_knowledge = knowledge
        print(prompt)
        return strategy, knowledge

def get_obs_message(observation):
    """Render an observation dict as the text block used in reflection prompts.

    Each key is matched by substring, in this order:
    - a key containing "observation": its string value is added verbatim;
    - a key containing "message": its entries (each holding "message",
      "sender" and "age_seconds") are rendered as a delimited dialog;
    - a key containing "task": rendered as "Your task is: ...";
    - any other key: rendered as "key: value".

    Returns "No observation" when ``observation`` is None.
    """
    if observation is None:
        return "No observation"

    pieces = ["Observation: "]
    for key, value in observation.items():
        if "observation" in key:
            pieces.append(value + " ")
        elif "message" in key:
            pieces.append("\n-------------------\n")
            pieces.append("Start of Message Dialog: ")
            pieces.extend(
                f"(Message from {entry['sender']}, {entry['age_seconds']} seconds ago: {entry['message']}) \n"
                for entry in value
            )
            pieces.append("\n-------------------\n")
            pieces.append("End of Message Dialog. Only vehicles in the dialog are able to communicate with each other. ")
        elif "task" in key:
            pieces.append(f"Your task is: {value} \n")
        else:
            pieces.append(f"{key}: {value} ")
    return "".join(pieces)